import time
import websocket
from decimal import Decimal
from threading import Thread, Lock
from collections import deque
from quant.markets import functions
from quant.accounts import UserData
from quant.utils import logging, catch_exception, except_caught_fn, Iota, Saving, LimitDict, Timer2
from quant.const import *


class Socket:
    """Reconnecting wrapper around ``websocket.WebSocketApp``.

    Construction starts a daemon thread that idles until ``connect()`` sets
    ``_keep_running``; it then runs the app in a loop, sleeping 5 seconds
    between sessions. User callbacks are wrapped with ``except_caught_fn``;
    the message and close callbacks run while holding the module-level
    ``socket_lock``. Every constructed instance is appended to
    ``Socket.instances``.
    """

    # Class-level registry shared by every Socket ever constructed.
    instances = []

    def __init__(self, host, open_fn, msg_fn, close_fn, pre_connect_fn=None):
        self.open_fn = except_caught_fn(open_fn)
        self.msg_fn = except_caught_fn(msg_fn)
        self.close_fn = except_caught_fn(close_fn)
        if pre_connect_fn:
            pre_connect_fn = except_caught_fn(pre_connect_fn)
        self.pre_connect_fn = pre_connect_fn

        self._keep_running = False
        self._is_open = False
        self._last_msg_time = 0

        # Callbacks are passed by keyword: newer websocket-client releases add
        # ``on_reconnect`` right after ``on_open``, which would silently shift
        # positionally-passed handlers into the wrong slots.
        self._socket = websocket.WebSocketApp(
            host,
            on_open=self._on_open,
            on_message=self._on_message,
            on_error=self._on_error,
            on_close=self._on_close,
        )
        self._threading = Thread(target=self._run, daemon=True)
        self._threading.start()
        self.instances.append(self)

    def connect(self):
        """Allow the worker thread to start (and keep restarting) the socket."""
        self._keep_running = True

    def reconnect(self):
        """Close the current connection; the worker loop will open a new one."""
        self._socket.close()

    def disconnect(self):
        """Stop the reconnect loop and close the current connection."""
        self._keep_running = False
        self._socket.close()

    def is_connected(self):
        """Return True between the open and close callbacks."""
        return self._is_open

    def set_host(self, host):
        """Change the URL used by the next connection attempt."""
        self._socket.url = host

    def register_ping(self, interval, build_ping):
        """Periodically send ``build_ping()`` while connected (via Timer2)."""
        def on_timer():
            if self.is_connected():
                msg = build_ping()
                self.send(msg)
        Timer2(on_timer, interval).start()

    def send(self, msg):
        """Send ``msg`` on the underlying websocket immediately."""
        self._socket.send(msg)

    def assert_send(self, msg, timeout=3):
        """Send now if connected, otherwise retry in the background for ~timeout seconds."""
        if self.is_connected():
            self.send(msg)
            return
        Thread(target=self._try_send, args=(msg, timeout), daemon=True).start()

    def get_host(self):
        """Return the URL the socket connects to."""
        return self._socket.url

    def _run(self):
        """Worker loop: run the websocket whenever ``_keep_running`` is set."""
        while True:
            if not self._keep_running:
                time.sleep(1)
                continue

            if self.pre_connect_fn:
                self.pre_connect_fn(self._socket)

            try:
                self._socket.run_forever(skip_utf8_validation=True)
            except Exception:
                catch_exception()
            # run_forever has returned or raised, so the session is over even
            # if on_close was never delivered.
            self._is_open = False
            time.sleep(5)

    def _on_open(self, ws):
        logging.info('Socket connected: {}'.format(self._socket.url))
        self._is_open = True
        self.open_fn(ws)

    def _on_message(self, ws, msg):
        self._last_msg_time = time.time()
        # Serialize message handling across all Socket instances.
        with socket_lock:
            self.msg_fn(ws, msg)

    def _on_error(self, ws, err):
        err = 'Socket.Error "{}" on {}'.format(err, self._socket.url)
        logging.error(err)

    def _on_close(self, ws, close_status_code, close_msg):
        logging.info('Socket closed: {} {} on {}'.format(close_status_code, close_msg, self._socket.url))
        self._is_open = False
        with socket_lock:
            self.close_fn(ws)

    def _try_send(self, msg, timeout):
        """Poll once per second for up to ``timeout`` seconds and send once connected."""
        waited = 0
        # Strict '<' so an integer timeout yields exactly `timeout` polls.
        while waited < timeout:
            time.sleep(1)
            waited += 1
            if self.is_connected():
                self.send(msg)
                return
        logging.error('Socket.assert_send({}) fail'.format(msg))


class AscendingChecker:
    """Accept only ids strictly greater than every id accepted so far."""

    def __init__(self):
        # Negative infinity makes the first comparable id count as new.
        self.last = float('-inf')

    def is_new(self, data_id):
        """Return True (and remember ``data_id``) if it exceeds the last accepted id."""
        if not data_id > self.last:
            return False
        self.last = data_id
        return True


class RecentChecker:
    """Deduplicate ids against a bounded history stored in a ``LimitDict``."""

    def __init__(self, history_count):
        # presumably LimitDict caps itself at history_count entries — TODO confirm
        self.limit_dict = LimitDict(history_count)

    def is_new(self, data_id):
        """Return True and record ``data_id`` unless it is already in the history."""
        seen = self.limit_dict
        fresh = data_id not in seen
        if fresh:
            seen[data_id] = data_id
        return fresh


class RecentChecker2:
    """Lock-guarded deduplicator remembering the latest ``history_count`` ids.

    A deque keeps arrival order for eviction; a set answers membership.
    """

    def __init__(self, history_count):
        self._lock = Lock()
        self._queue = deque()
        self._set = set()
        self._history_count = history_count

    def __repr__(self):
        header = '------recent_checker_2------'
        return '{}\n{}\n{}'.format(header, sorted(self._queue), sorted(self._set))

    def is_new(self, data_id):
        """Return True and remember ``data_id`` unless it is among the retained ids."""
        with self._lock:
            if data_id in self._set:
                return False

            self._queue.append(data_id)
            self._set.add(data_id)

            # Drop the oldest ids once the history exceeds its cap.
            limit = self._history_count
            while len(self._queue) > limit:
                oldest = self._queue.popleft()
                self._set.remove(oldest)

            return True


class _AllIncludeSet:
    def __contains__(self, item):
        return True


class NullCls:
    """Null object: construction and any unknown-attribute call do nothing."""

    def __init__(self, *args, **kwargs):
        # Accept and ignore any constructor arguments.
        return None

    def __getattr__(self, item):
        # Reached only for attributes not found normally; returning the
        # no-op bound method makes ``obj.anything(...)`` silently succeed.
        return self._func

    def _func(self, *args, **kwargs):
        return None


def get_agent_name(symbol):
    """Return the agent name for ``symbol``.

    Symbols without a '.' map to 'spot'. Otherwise the text after the last '/'
    is returned, e.g. 'BTC/USDT.swap' -> 'USDT.swap'.
    NOTE(review): the guard tests for '.' but the split is on '/', so the quote
    asset is part of the name — confirm this is intended.
    """
    if '.' not in symbol:
        return 'spot'
    return symbol.rpartition('/')[2]


def create_md_id():
    """Return the next value from the module-level ``_md_iota`` source."""
    next_id = _md_iota.next()
    return next_id


def parse_params(params):
    """Join a mapping into a ``k=v&k=v`` string in the mapping's iteration order.

    Keys and values are formatted as-is; no URL encoding is applied.
    """
    pieces = []
    for key, value in params.items():
        pieces.append('{}={}'.format(key, value))
    return '&'.join(pieces)


def parse_params_sorted(params):
    """Join a mapping into a ``k=v&k=v`` string with pairs ordered by key.

    No URL encoding is applied.
    """
    ordered = sorted(params.items(), key=_sort_param_key)
    return '&'.join('{}={}'.format(k, v) for k, v in ordered)


def _sort_param_key(x):
    return x[0]



def check_book_cross(book):
    """Detect a crossed top of book and clear both offending levels.

    Reads the first level of each side via ``book.item(side, 1)``. If the buy
    price is >= the sell price, both levels are set to 0 with
    ``book[side, price] = 0`` and the comparison result is returned.
    Prices that cannot be compared (TypeError — presumably an empty side
    yielding None) count as not crossed.
    """
    best_bid, _bid_amount = book.item('buy', 1)
    best_ask, _ask_amount = book.item('sell', 1)

    try:
        crossed = best_bid >= best_ask
    except TypeError:
        crossed = False

    if not crossed:
        return crossed

    book['buy', best_bid] = 0
    book['sell', best_ask] = 0
    return crossed


def next_cid_head():
    """Return a two-letter client-id prefix; each letter comes from its own rotating counter."""
    first = _get_cid_head_1()
    second = _get_cid_head_2()
    return first + second


def _get_cid_head_1():
    """Advance the 'cid_head_ord' counter stored via ``Saving`` and return it as a letter.

    The counter wraps from 'z' (122) back to 'a' (97). A missing entry is
    treated as 122, so the first letter produced is 'a'.
    """
    key = 'cid_head_ord'
    store = Saving(key)
    try:
        current = store.get(key)
    except KeyError:
        current = ord('z')

    following = current + 1
    if following > ord('z'):
        following = ord('a')
    store.set(key, following)

    return chr(following)


def _get_cid_head_2():
    """Advance the 'cid_head_ord_2' counter stored via ``Saving`` and return it as a letter.

    The counter wraps from 'z' (122) back to 'b' (98), giving a 25-letter
    cycle. NOTE(review): presumably chosen so that, paired with the 26-letter
    'cid_head_ord' cycle, the two-letter prefix repeats only every 650
    calls — confirm. A missing entry is treated as 122, so the first letter
    produced is 'b'.
    """
    key = 'cid_head_ord_2'
    store = Saving(key)
    try:
        current = store.get(key)
    except KeyError:
        current = ord('z')

    following = current + 1
    if following > ord('z'):
        following = ord('b')
    store.set(key, following)

    return chr(following)


def int_prec_to_precision(n):
    """Convert a number of decimal places into a step-size string.

    Values below 1 give '1'; 1 -> '0.1', 3 -> '0.001'.
    """
    if n < 1:
        return '1'
    return '0.' + '0' * (n - 1) + '1'


def spot_value_factor(symbol):
    """Return ``functions.get_asset_value`` for the base asset of a 'BASE/QUOTE' symbol.

    Raises ValueError unless ``symbol`` contains exactly one '/'.
    """
    base, _quote = symbol.split('/')
    return functions.get_asset_value(base)


def simulate_deal_spot(order, deal_amount, mid_price, account):
    """Apply a simulated spot fill to ``account.balance`` and publish the result.

    The symbol is parsed as 'BASE/QUOTE[.suffix]'. A buy adds ``deal_amount``
    of BASE and removes ``order.price * deal_amount`` of QUOTE; a sell does
    the reverse. Only the 'free' amounts change; missing assets are created
    as ``{'free': 0, 'frozen': 0}``. The balance is then pushed to both the
    info and routing engines. ``mid_price`` is unused here (presumably kept so
    all simulate_deal_* functions share a signature).
    """
    symbol = order.symbol
    balance = account.balance
    base, quote = symbol.split('.')[0].split('/')
    side = order.side

    base_delta = deal_amount
    quote_delta = order.price * deal_amount
    if side == 'buy':
        quote_delta = -quote_delta
    else:
        base_delta = -base_delta

    for asset in (quote, base):
        if asset not in balance:
            balance[asset] = {'free': 0, 'frozen': 0}

    balance[base]['free'] += base_delta
    balance[quote]['free'] += quote_delta

    update = UserData(UserEvent.Balance, ApiType.Spi, balance)
    account.info_engine.put_user_data(update)
    account.routing_engine.put_user_data(update)


def simulate_deal_usdt_swap(order, deal_amount, mid_price, account, size_factor=1):
    """Apply a simulated swap fill and publish the new position and margin.

    Matching does not apply ``size_factor``; valuation does. The raw
    ``deal_amount`` updates ``account.simulate_balance``. The holdings are
    then marked at ``mid_price``, scaled by ``size_factor``, and added to
    ``account.begin_margin`` to give the quote-asset margin. The position
    entry is updated via ``simulate_update_position``. Position and margin
    updates are each pushed to the info and routing engines.
    """
    symbol = order.symbol
    holdings = account.simulate_balance
    base, quote = symbol.split('.')[0].split('/')
    side = order.side

    notional = order.price * deal_amount
    if side == 'buy':
        base_delta, quote_delta = deal_amount, -notional
    else:
        base_delta, quote_delta = -deal_amount, notional

    holdings[base] = holdings.get(base, 0) + base_delta
    holdings[quote] = holdings.get(quote, 0) + quote_delta

    pnl = (holdings[base]*mid_price + holdings[quote]) * size_factor
    account.margin[quote] = account.begin_margin.get(quote, 0) + pnl

    position = simulate_update_position(account.position, symbol, side, deal_amount)

    position_update = UserData(UserEvent.Position, ApiType.Spi, position)
    account.info_engine.put_user_data(position_update)
    account.routing_engine.put_user_data(position_update)

    margin_update = UserData(UserEvent.Margin, ApiType.Spi, account.margin)
    account.info_engine.put_user_data(margin_update)
    account.routing_engine.put_user_data(margin_update)


def simulate_deal_usdt_swap_2(order, deal_amount, mid_price, account, size_factor=1):
    """Apply a simulated swap fill, scaled by ``size_factor``, and publish position/margin.

    Matching applies ``size_factor``. ``deal_amount * size_factor`` updates
    ``account.simulate_balance``. Holdings marked at ``mid_price`` are then
    added to ``account.begin_margin`` to give the quote-asset margin.
    NOTE(review): the position update uses the unscaled ``deal_amount`` —
    presumably positions are kept in contract units; confirm.
    """
    symbol = order.symbol
    holdings = account.simulate_balance
    base, quote = symbol.split('.')[0].split('/')
    side = order.side

    sized_amount = deal_amount * size_factor
    notional = order.price * sized_amount
    if side == 'buy':
        base_delta, quote_delta = sized_amount, -notional
    else:
        base_delta, quote_delta = -sized_amount, notional

    holdings[base] = holdings.get(base, 0) + base_delta
    holdings[quote] = holdings.get(quote, 0) + quote_delta

    pnl = holdings[base]*mid_price + holdings[quote]
    account.margin[quote] = account.begin_margin.get(quote, 0) + pnl

    position = simulate_update_position(account.position, symbol, side, deal_amount)

    position_update = UserData(UserEvent.Position, ApiType.Spi, position)
    account.info_engine.put_user_data(position_update)
    account.routing_engine.put_user_data(position_update)

    margin_update = UserData(UserEvent.Margin, ApiType.Spi, account.margin)
    account.info_engine.put_user_data(margin_update)
    account.routing_engine.put_user_data(margin_update)


def simulate_update_position(position, symbol, side, amount):
    """Apply a fill of ``amount`` on ``side`` to the net position of ``symbol``.

    ``position`` is mutated in place (an entry is created on first use) and
    also returned. A positive net is reported as 'long'; otherwise its
    magnitude is reported as 'short'.
    """
    if symbol not in position:
        position[symbol] = {'long': 0, 'short': 0, 'position': 0}
    entry = position[symbol]

    held = entry['position']
    net = held + amount if side == 'buy' else held - amount

    entry['position'] = net
    if net > 0:
        entry['long'], entry['short'] = net, 0
    else:
        entry['long'], entry['short'] = 0, -net
    return position


def update_balance_default(balance, asset, free=None, frozen=None):
    """Set the 'free' and/or 'frozen' amount of ``asset`` in ``balance`` in place.

    An unseen asset is created as ``{'free': 0, 'frozen': 0}`` first.
    Arguments left as None leave the corresponding field untouched.
    Returns None.
    """
    # setdefault returns the existing entry when present, so a separate
    # membership check is redundant.
    bal = balance.setdefault(asset, {'free': 0, 'frozen': 0})
    if free is not None:
        bal['free'] = free
    if frozen is not None:
        bal['frozen'] = frozen


def prepare_placing(order, client_id):
    """Return a copy of ``order`` set up for a new placement request.

    The copy must not already carry a client_id. Its order_type defaults to
    ``OrderType.Limit`` and orig_amount is set to amount. The copy's
    ``operate_now()`` is called, its status becomes ``OrderStatus.Placing``,
    and ``client_id`` is attached.
    NOTE(review): the client_id check is an ``assert`` and is stripped under -O.
    """
    placing = order.copy()
    assert placing.client_id is None

    if placing.order_type is None:
        placing.order_type = OrderType.Limit

    placing.orig_amount = placing.amount
    placing.operate_now()
    placing.status = OrderStatus.Placing
    placing.client_id = client_id
    return placing


def prepare_modifying(new, client_id):
    """Return a copy of ``new`` set up as the replacement order in a modify.

    This mirrors placement preparation, except that an existing orig_amount is
    kept; it falls back to amount only when it is None. The copy must not
    already carry a client_id. Its order_type defaults to ``OrderType.Limit``
    and ``operate_now()`` is called. Status becomes ``OrderStatus.Placing``
    and ``client_id`` is attached.
    NOTE(review): the client_id check is an ``assert`` and is stripped under -O.
    """
    draft = new.copy()
    assert draft.client_id is None

    if draft.order_type is None:
        draft.order_type = OrderType.Limit

    if draft.orig_amount is None:
        draft.orig_amount = draft.amount

    draft.operate_now()
    draft.status = OrderStatus.Placing
    draft.client_id = client_id
    return draft


def prepare_canceling(order):
    """Return ``order.copy()`` after ``operate_now()``, with status ``OrderStatus.Canceling``."""
    canceling = order.copy()
    canceling.operate_now()
    canceling.status = OrderStatus.Canceling
    return canceling


def avoid_sci_num(price):
    """Render ``price`` in plain decimal notation when its ``str()`` is scientific.

    ``None`` is returned unchanged. Any value whose ``str()`` has no exponent
    marker is also returned unchanged (the original object, not a string).
    Otherwise a fixed-point string is returned, e.g. ``1e-07`` ->
    '0.0000001', ``-1.5e-07`` -> '-0.00000015', ``1e+20`` ->
    '100000000000000000000'.
    """
    if price is None:
        return None

    str_p = str(price)
    # Floats print a lower-case 'e'; Decimal prints an upper-case 'E'.
    if 'e' not in str_p.lower():
        return price

    # Decimal parses exactly the digits str() produced. Its 'f' format expands
    # the exponent, which correctly covers negative values and positive
    # exponents as well.
    return format(Decimal(str_p), 'f')


# Module-wide lock held by Socket._on_message and Socket._on_close while the
# user callbacks run, so those callbacks are serialized across all sockets.
socket_lock = Lock()
# Shared instance whose ``in`` test is always True.
all_include_set = _AllIncludeSet()


# Id source consumed by create_md_id().
_md_iota = Iota()


if __name__ == '__main__':
    sample = 0.0000001  # str() renders this float as '1e-07'
    print(avoid_sci_num(sample))












